import torch
import torch.nn as nn
import torch.nn.functional as F


class multi_dRNN_with_dilations(nn.Module):
    """Stack of dilated LSTM layers applied to a time-major sequence.

    One ``nn.LSTMCell`` is created per entry of ``hidden_structs``; in
    ``forward`` the cells are paired with ``dilations`` via ``zip``, so any
    surplus entries on either side are ignored.

    Args:
        hidden_structs: Hidden size of each stacked layer, in order.
        dilations: Dilation rate of each layer (positive integers).
        input_dims: Feature dimension of the input sequence.
        device: Stored as ``self.device``; used as the default device by
            ``init_hidden`` when no explicit device is given.
    """

    def __init__(self, hidden_structs, dilations, input_dims, device):
        super(multi_dRNN_with_dilations, self).__init__()

        self.device = device
        self.hidden_structs = hidden_structs
        self.dilations = dilations

        # Each layer consumes the hidden output of the previous one.
        self.cells = nn.ModuleList()
        last_hidden_dim = input_dims
        for hidden_dims in hidden_structs:
            self.cells.append(nn.LSTMCell(last_hidden_dim, hidden_dims))
            last_hidden_dim = hidden_dims

    def forward(self, x):
        """Run every dilated layer over ``x``.

        Args:
            x: Tensor of shape [seq_len, bsz, fea_dim].

        Returns:
            Tensor of shape [seq_len, bsz, hidden], where ``hidden`` is the
            hidden size of the last layer that was applied.
        """
        seq_len, _, _ = x.size()
        inputs = [x[i] for i in range(seq_len)]  # list of [bsz, fea_dim] tensors

        for cell, dilation in zip(self.cells, self.dilations):
            inputs = self.dRNN(cell, inputs, dilation)

        # Stack the per-step outputs back to [seq_len, bsz, hidden]
        return torch.stack(inputs, dim=0)

    def dRNN(self, cell, inputs, rate):
        """Apply one LSTM cell with dilation ``rate`` to a list of time steps.

        ``rate`` consecutive steps are concatenated along the batch axis and
        fed to the cell together, so the state used at step ``t`` is the one
        produced at step ``t - rate``.

        Args:
            cell: The ``nn.LSTMCell`` of this layer.
            inputs: List of [bsz, dim] tensors, one per time step. The list
                is not modified.
            rate: Dilation rate, an integer >= 1.

        Returns:
            List of ``len(inputs)`` tensors of shape [bsz, cell.hidden_size].

        Raises:
            ValueError: If ``rate`` is smaller than 1.
        """
        if rate < 1:
            raise ValueError("dilation rate must be >= 1, got {}".format(rate))

        n_steps = len(inputs)
        first = inputs[0]
        batch_size = first.size(0)

        # Zero-pad up to a multiple of `rate` on a copy, so the caller's list
        # keeps its original length. zeros_like already matches the input's
        # device and dtype.
        remainder = n_steps % rate
        if remainder:
            padded = list(inputs) + [torch.zeros_like(first)] * (rate - remainder)
        else:
            padded = inputs

        dilated_inputs = [
            torch.cat(padded[start:start + rate], dim=0)
            for start in range(0, len(padded), rate)
        ]

        # Build the initial state on the input's device/dtype so the layer
        # keeps working after the module is moved or cast.
        hidden, cell_state = self.init_hidden(
            batch_size * rate, cell.hidden_size,
            device=first.device, dtype=first.dtype,
        )
        dilated_outputs = []
        for dilated_input in dilated_inputs:
            hidden, cell_state = cell(dilated_input, (hidden, cell_state))
            dilated_outputs.append(hidden)

        # Undo the batch-axis concatenation and drop the padded steps.
        outputs = [
            piece
            for output in dilated_outputs
            for piece in torch.chunk(output, rate, 0)
        ]
        return outputs[:n_steps]

    def init_hidden(self, batch_size, hidden_dim, device=None, dtype=None):
        """Return zero-filled ``(hidden, cell_state)`` tensors.

        Args:
            batch_size: Number of rows of each state tensor.
            hidden_dim: Number of columns of each state tensor.
            device: Target device; defaults to ``self.device``.
            dtype: Target dtype; defaults to torch's default dtype.
        """
        if device is None:
            device = self.device
        return (torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype),
                torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype))

class ClusteringLayer(nn.Module):
    """Soft assignment of latent vectors to learnable cluster centres.

    Args:
        n_clusters: Number of cluster centres.
        dims: Dimension of each centre (must match the last axis of the
            input to ``forward``).
        student_tdist: If truthy, similarities come from a Student's
            t-kernel normalised over clusters; otherwise a cosine similarity
            rescaled to [0, 1] is used.
        alpha: Degrees of freedom of the Student's t-kernel.
        config: Mapping read with ``.get``. ``'init_method'`` selects the
            initialisation: ``'he'``, ``'xavier'`` or ``'generic'`` (normal
            with ``'mean'`` / ``'std_dev'``, defaulting to 0.0 / 1.0). A
            missing or unrecognised value falls back to Xavier-normal so the
            centres never hold uninitialised memory.
    """

    def __init__(self, n_clusters, dims, student_tdist, alpha, config):
        super(ClusteringLayer, self).__init__()
        self.n_clusters = n_clusters
        self.cluster_centers = nn.Parameter(torch.empty(n_clusters, dims), requires_grad=True)
        self.student_tdist = student_tdist
        self.alpha = alpha

        init_method = config.get('init_method')
        if init_method == 'he':
            # He initialization
            nn.init.kaiming_normal_(self.cluster_centers.data, mode='fan_in', nonlinearity='relu')
        elif init_method == 'generic':
            # Plain normal initialization with configurable mean / std
            std_dev = config.get('std_dev', 1.0)
            mean = config.get('mean', 0.0)
            nn.init.normal_(self.cluster_centers.data, mean=mean, std=std_dev)
        else:
            # 'xavier', and the fallback for a missing/unknown method:
            # torch.empty leaves arbitrary values in memory, so some
            # initialiser must always run.
            nn.init.xavier_normal_(self.cluster_centers.data)

    def forward(self, x):
        """Compute the similarity of every time step to every centre.

        Args:
            x: Tensor of shape [batch_size, seq_len, hidden_dim].

        Returns:
            Tensor of shape [batch_size, seq_len, n_clusters]. In the cosine
            branch each value lies in [0, 1]; in the Student's t branch the
            values are normalised by their sum over the cluster axis.
        """
        # Flatten to [batch_size * seq_len, hidden_dim]
        x_flattened = x.reshape(-1, x.size(2))

        if not self.student_tdist:
            # Broadcast [N, 1, D] against [1, K, D] to get all pairwise
            # cosine similarities at once, in the input's dtype.
            cosine = F.cosine_similarity(
                x_flattened.unsqueeze(1), self.cluster_centers.unsqueeze(0), dim=2)
            # Map [-1, 1] to [0, 1]; the clamp guards against rounding drift.
            similarities = torch.clamp((cosine + 1) / 2, min=0.0, max=1.0)
        else:
            # Student's t-kernel on squared Euclidean distances
            sq_dist = torch.sum(
                torch.pow(x_flattened.unsqueeze(1) - self.cluster_centers, 2), 2)
            similarities = (1.0 / (1.0 + sq_dist / self.alpha)).pow((self.alpha + 1.0) / 2.0)
            # Normalise each row over the cluster axis
            similarities = similarities / torch.sum(similarities, 1, keepdim=True)

        # Reshape back to [batch_size, seq_len, n_clusters]
        return similarities.reshape(x.size(0), x.size(1), self.n_clusters)

class Time_Series_Deep_SVDD(nn.Module):
    """Dilated-RNN encoder followed by a clustering layer with thresholds.

    Args:
        config: Mapping providing ``'hidden_structs'``, ``'dilations'``,
            ``'feature_size'``, ``'cluster_num'``, ``'device'`` and
            ``'init_threshold'``. It is also passed on to ``ClusteringLayer``.
            With ``cluster_num >= 2`` the Student's t branch is used,
            otherwise the thresholded branch.
    """

    def __init__(self, config):
        super(Time_Series_Deep_SVDD, self).__init__()
        self.hidden_structs = config['hidden_structs']
        self.dilations = config['dilations']
        self.input_dims = config['feature_size']
        self.cluster_num = config['cluster_num']
        self.hidden_dim = config['hidden_structs'][-1]
        self.device = config['device']

        # Student's t soft assignment is only used with two or more clusters.
        self.student_tdist = self.cluster_num >= 2
        self.alpha = 1.0

        self.encoder = multi_dRNN_with_dilations(
            hidden_structs=self.hidden_structs,
            dilations=self.dilations,
            input_dims=self.input_dims,
            device=self.device,
        )

        self.centroid = ClusteringLayer(
            self.cluster_num, self.hidden_dim, self.student_tdist, self.alpha, config)

        # One learnable threshold per cluster. The fill value is cast to
        # float because torch.full infers an integer dtype from an int fill
        # value, and integer tensors cannot require gradients.
        self.threshold = nn.Parameter(
            torch.full((self.cluster_num,), float(config['init_threshold'])),
            requires_grad=True)

    def forward(self, x):
        """Encode ``x`` and compute cluster assignments.

        Args:
            x: Tensor of shape [batch_size, seq_len, fea_dim].

        Returns:
            Tuple ``(latent, p, q, cluster_centers, threshold)`` where
            ``latent`` is [batch_size, seq_len, hidden_dim], ``q`` is the
            clustering layer output, and ``p`` is either the 0/1 mask
            ``q >= threshold`` (thresholded branch) or the row-normalised
            ``q**2 / q.sum(dim=(0, 1))`` (Student's t branch).
        """
        reshape_x = x.permute(1, 0, 2)                   # [seq_len, batch_size, fea_dim]
        x_latent = self.encoder(reshape_x)               # [seq_len, batch_size, hidden_dim]
        reshape_x_latent = x_latent.permute(1, 0, 2)     # [batch_size, seq_len, hidden_dim]

        # Similarity of every time step to every cluster centre
        q = self.centroid(reshape_x_latent)

        if not self.student_tdist:
            # Hard mask against the per-cluster threshold
            p = (q >= self.threshold.unsqueeze(0).unsqueeze(0)).float()
        else:
            # Square q, divide by the per-cluster sum over batch and time,
            # then renormalise over the cluster axis.
            weight = q ** 2 / q.sum(dim=(0, 1))
            p = weight / weight.sum(dim=2, keepdim=True)
        return reshape_x_latent, p, q, self.centroid.cluster_centers, self.threshold

    def get_model_params(self):
        """Return all parameters except ``threshold``."""
        return [param for name, param in self.named_parameters() if name != 'threshold']

    def get_thre_param(self):
        """Return the ``threshold`` parameter wrapped in a list."""
        return [self.threshold]
